import argparse
import functools
import os

import numpy as np

import pose_estimation._init_paths
from lib.core.config import config
from lib.core.config import update_config


def parse_args():
    """Parse the command line for this test script in two passes.

    Pass one only knows ``--cfg`` and ``--port``; it is used to load the
    experiment YAML into the global ``config`` via ``update_config``.  Pass
    two registers the remaining options (``--frequent`` takes its default
    from the freshly loaded ``config.PRINT_FREQ``) and parses everything
    strictly.

    Returns:
        argparse.Namespace with ``cfg``, ``port``, ``frequent``, ``gpus``,
        ``workers`` and ``iteration``.
    """
    arg_parser = argparse.ArgumentParser(description='Train keypoints network')

    # -- general options: needed before the config file can be loaded --
    arg_parser.add_argument('--cfg', help='experiment configure file name',
                            default='/Your_File_Directory/exp.yaml', type=str)
    arg_parser.add_argument('--port', default=12590, type=int)

    # Tolerate the not-yet-registered options on this first pass.
    early_args, _unrecognised = arg_parser.parse_known_args()
    update_config(early_args.cfg)

    # -- training options: defaults may depend on the loaded config --
    arg_parser.add_argument('--frequent', help='frequency of logging',
                            default=config.PRINT_FREQ, type=int)
    arg_parser.add_argument('--gpus', help='gpus', type=str)
    arg_parser.add_argument('--workers', help='num of dataloader workers',
                            type=int)
    arg_parser.add_argument('--iteration', help='the kth times of training',
                            type=int, choices=range(1, 10), default=1)

    return arg_parser.parse_args()


def reset_config(config, args):
    """Apply command-line overrides from ``args`` onto ``config`` in place.

    Args:
        config: config object whose ``GPUS`` / ``WORKERS`` attributes may be
            overwritten.
        args: namespace from ``parse_args``; ``gpus`` (str) and ``workers``
            (int) are ``None`` when the corresponding flag was not given.
    """
    # Truthiness on purpose: an empty ``--gpus ''`` would make the later
    # ``[int(i) for i in config.GPUS.split(',')]`` fail, so ignore it.
    if args.gpus:
        config.GPUS = args.gpus
    # ``is not None`` so an explicit ``--workers 0`` is honoured instead of
    # being mistaken for "flag not given" (0 is falsy).
    if args.workers is not None:
        config.WORKERS = args.workers


class UNI_TEST():
    """Ad-hoc smoke tests for the pose-estimation models and datasets.

    Constructing an instance parses the command line (``parse_args``),
    applies the ``--gpus`` / ``--workers`` overrides to the global ``config``
    (``reset_config``) and immediately runs whichever ``_Test_*`` methods are
    enabled in ``__init__``.  Each test writes its images under
    ``./uni_test/<test name>/``.
    """

    def __init__(self, **kwargs):
        # ``kwargs`` is accepted for call compatibility but is not used.
        args = parse_args()
        reset_config(config, args)

        # Enable/disable individual tests by (un)commenting these calls.
        # self._Test_ResNet50_model()
        self._Test_ResNet50APS_model()
        # self._Test_MPIIDataset_UDP()

    @staticmethod
    def _resolve_attr(root, dotted_name):
        """Return the attribute ``root.<dotted_name>`` (dots allowed).

        Used instead of ``eval('module.' + name)`` so that names read from
        the YAML config are only ever looked up as attributes, never executed.
        """
        return functools.reduce(getattr, dotted_name.split('.'), root)

    def _make_valid_loader(self, batch_size, num_workers):
        """Build an unshuffled DataLoader over ``config.DATASET.TEST_DATASET[0]``.

        The dataset class is looked up by name in ``lib.dataset`` and built
        with a third positional argument ``False`` (presumably ``is_train`` --
        TODO confirm) and a ToTensor + per-channel Normalize transform.
        """
        import torch
        import torchvision.transforms as transforms
        import lib.dataset as dataset

        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        test_set_cfg = config.DATASET.TEST_DATASET[0]
        dataset_cls = self._resolve_attr(dataset, test_set_cfg.DATASET)
        valid_dataset = dataset_cls(
            config,
            test_set_cfg,
            False,
            transforms.Compose([
                transforms.ToTensor(),
                normalize,
            ])
        )
        return torch.utils.data.DataLoader(
            valid_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            pin_memory=True
        )

    def _Test_ResNet50_model(self):
        """Run ``config.MODEL.NAME`` on one validation batch and save heatmaps.

        Writes ground-truth and predicted heatmap grids to
        ``./uni_test/_Test_ResNet50_model/``.  Requires CUDA.
        """
        import torch
        from lib.utils.vis import save_batch_heatmaps
        import lib.models as models

        gpus = [int(i) for i in config.GPUS.split(',')]
        # DataParallel requires the wrapped module to live on device_ids[0];
        # hard-coding cuda:0 broke runs such as ``--gpus 1,2``.
        device = torch.device('cuda:{}'.format(gpus[0]))
        get_pose_net = self._resolve_attr(
            models, config.MODEL.NAME + '.get_pose_net')
        model = get_pose_net(config, is_train=True).to(device)
        # No trailing ``.cuda()``: it would move the parameters to the
        # *current* CUDA device, which need not be gpus[0].
        model = torch.nn.DataParallel(model, device_ids=gpus)

        print('=> loading model from {}'.format(config.TEST.MODEL_FILE))
        # NOTE: weight loading is disabled, so despite the message above the
        # model runs with the weights get_pose_net initialised it with.
        # model.load_state_dict(torch.load(config.TEST.MODEL_FILE))

        model.eval()

        valid_loader = self._make_valid_loader(
            batch_size=5, num_workers=config.WORKERS // len(gpus))

        output_dir = './uni_test/_Test_ResNet50_model/'
        with torch.no_grad():
            # Only the first batch is inspected.
            for input_data, target, _, _ in valid_loader:
                input_data = input_data.to(device, non_blocking=False)
                output = model(input_data)

                os.makedirs(output_dir, exist_ok=True)
                prefix = '{}_{}'.format(os.path.join(output_dir, 'train'), 0)
                save_batch_heatmaps(
                    input_data, target, '{}_hm_gt.jpg'.format(prefix)
                )
                save_batch_heatmaps(
                    input_data, output, '{}_hm_pred.jpg'.format(prefix)
                )
                break
            print('_Test_ResNet50_model Done')

    def _Test_ResNet50APS_model(self):
        """Run ``pose_resnet_aps`` on one batch and on a circularly shifted copy.

        Writes ground-truth, predicted, and shifted-input predicted heatmap
        grids to ``./uni_test/_Test_ResNet50APS_model/`` so the model's
        response to a small input shift can be compared visually.
        Requires CUDA.
        """
        import torch
        from lib.utils.vis import save_batch_heatmaps
        import lib.models as models

        # ``gpus`` only scales the worker count; the model itself is
        # deliberately pinned to a single device (cuda:0).
        gpus = [int(i) for i in config.GPUS.split(',')]
        device = torch.device('cuda:0')
        model = models.pose_resnet_aps.get_pose_net(
            config, is_train=True
        ).to(device)
        model = torch.nn.DataParallel(model, device_ids=[0], output_device=0).cuda()

        model.eval()

        valid_loader = self._make_valid_loader(
            batch_size=2, num_workers=config.WORKERS // len(gpus))

        # Pixel shift applied along batch dims (2, 3); ToTensor yields
        # (C, H, W), so these are (H, W) of the batched input.
        shift = (1, 0)
        output_dir = './uni_test/_Test_ResNet50APS_model/'
        with torch.no_grad():
            # Only the first batch is inspected.
            for input_data, target, _, _ in valid_loader:
                input_data = input_data.to(device, non_blocking=False)
                output = model(input_data)

                # torch.roll wraps pixels around (circular shift).
                shifted_input = torch.roll(input_data, shifts=shift, dims=(2, 3))
                shifted_output = model(shifted_input)

                os.makedirs(output_dir, exist_ok=True)
                prefix = '{}_{}'.format(os.path.join(output_dir, 'train'), 0)
                save_batch_heatmaps(
                    input_data, target, '{}_hm_gt.jpg'.format(prefix)
                )
                save_batch_heatmaps(
                    input_data, output, '{}_hm_pred.jpg'.format(prefix)
                )
                # Persist the shifted prediction; it was previously computed
                # and thrown away, leaving nothing to compare against.
                save_batch_heatmaps(
                    shifted_input, shifted_output,
                    '{}_hm_pred_shifted.jpg'.format(prefix)
                )
                break
            print('_Test_ResNet50APS_model Done')

    def _Test_MPIIDataset_UDP(self):
        """Fetch one batch from the training loader and visualise it.

        Prints the tensor shapes and writes joint and heatmap visualisations
        to ``./uni_test/_Test_MPIIDataset_UDP/``.  Mutates the global
        ``config`` (``TRAIN.BATCH_SIZE`` and ``TRAIN.SHUFFLE``).

        Returns:
            The first batch as ``(input_data, target, target_weight, meta)``.

        Raises:
            RuntimeError: if the training loader yields no batches.
        """
        import torchvision.transforms as transforms
        from lib.utils.utils import get_training_loader, get_training_set
        from lib.utils.vis import save_batch_image_with_joints, save_batch_heatmaps
        import torch.distributed as dist

        # get_training_loader hands back a sampler with ``set_epoch`` --
        # presumably a DistributedSampler, which needs a default process
        # group; a single-rank gloo group satisfies that.  Guarded so a
        # second call does not fail with "initialize ... twice".
        if not dist.is_initialized():
            dist.init_process_group('gloo', init_method='file:///tmp/somefile',
                                    rank=0, world_size=1)

        epoch = 10

        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

        config.TRAIN.BATCH_SIZE = 2
        train_dataset = get_training_set(config, transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ]))
        config.TRAIN.SHUFFLE = False
        train_loader, train_sampler = get_training_loader(train_dataset, config)
        train_sampler.set_epoch(epoch)

        # Only the first batch is inspected.
        first_batch = next(iter(train_loader), None)
        if first_batch is None:
            raise RuntimeError('training loader produced no batches')
        input_data, target, target_weight, meta = first_batch
        print(input_data.shape)
        print(target.shape)
        print(target_weight.shape)

        output_dir = './uni_test/_Test_MPIIDataset_UDP/'
        os.makedirs(output_dir, exist_ok=True)
        prefix = '{}_{}'.format(os.path.join(output_dir, 'train'), 0)
        save_batch_image_with_joints(
            input_data, meta['joints_2d_transformed'], meta['joints_vis'],
            '{}_gt.jpg'.format(prefix)
        )
        save_batch_heatmaps(
            input_data, target, '{}_hm_gt.jpg'.format(prefix)
        )
        print('_Test_MPIIDataset_UDP Done')

        return input_data, target, target_weight, meta



if __name__ == '__main__':
    # Run only when executed as a script: constructing UNI_TEST parses
    # sys.argv and launches model/dataset work, which must not happen merely
    # because the module was imported (e.g. during pytest collection).
    test = UNI_TEST()


